{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using DALI in PyTorch\n",
    "\n",
    "### Overview\n",
    "\n",
    "This example shows how to use DALI in PyTorch.\n",
    "\n",
    "This example uses readers.Caffe.\n",
    "See other [examples](../../index.rst) for details on how to use different data formats."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us start from defining some global constants\n",
    "\n",
    "`DALI_EXTRA_PATH` environment variable should point to the place where data from [DALI extra repository](https://github.com/NVIDIA/DALI_extra) is downloaded. Please make sure that the proper release tag is checked out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "import subprocess\n",
    "\n",
    "test_data_root = os.environ[\"DALI_EXTRA_PATH\"]\n",
    "\n",
    "# Caffe LMDB\n",
    "lmdb_folder = os.path.join(test_data_root, \"db\", \"lmdb\")\n",
    "\n",
    "res = subprocess.run([\"nvidia-smi\", \"-L\"], stdout=subprocess.PIPE, text=True)\n",
    "N = res.stdout.count(\"\\n\")  # number of GPUs\n",
    "BATCH_SIZE = 128  # batch size per GPU\n",
    "ITERATIONS = 32\n",
    "IMAGE_SIZE = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us define a pipeline with a reader:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nvidia.dali import pipeline_def, Pipeline\n",
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "\n",
    "@pipeline_def\n",
    "def caffe_pipeline(num_gpus):\n",
    "    device_id = Pipeline.current().device_id\n",
    "    jpegs, labels = fn.readers.caffe(\n",
    "        name=\"Reader\",\n",
    "        path=lmdb_folder,\n",
    "        random_shuffle=True,\n",
    "        shard_id=device_id,\n",
    "        num_shards=num_gpus,\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\")\n",
    "    images = fn.resize(\n",
    "        images,\n",
    "        resize_shorter=fn.random.uniform(range=(256, 480)),\n",
    "        interp_type=types.INTERP_LINEAR,\n",
    "    )\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images,\n",
    "        crop_pos_x=fn.random.uniform(range=(0.0, 1.0)),\n",
    "        crop_pos_y=fn.random.uniform(range=(0.0, 1.0)),\n",
    "        dtype=types.FLOAT,\n",
    "        crop=(227, 227),\n",
    "        mean=[128.0, 128.0, 128.0],\n",
    "        std=[1.0, 1.0, 1.0],\n",
    "    )\n",
    "\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us create the pipeline and pass it to PyTorch generic iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OK\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from nvidia.dali.plugin.pytorch import DALIGenericIterator\n",
    "\n",
    "\n",
    "label_range = (0, 999)\n",
    "pipes = [\n",
    "    caffe_pipeline(\n",
    "        batch_size=BATCH_SIZE, num_threads=2, device_id=device_id, num_gpus=N\n",
    "    )\n",
    "    for device_id in range(N)\n",
    "]\n",
    "\n",
    "for pipe in pipes:\n",
    "    pipe.build()\n",
    "\n",
    "dali_iter = DALIGenericIterator(pipes, [\"data\", \"label\"], reader_name=\"Reader\")\n",
    "\n",
    "for i, data in enumerate(dali_iter):\n",
    "    # Testing correctness of labels\n",
    "    for d in data:\n",
    "        label = d[\"label\"]\n",
    "        image = d[\"data\"]\n",
    "        ## labels need to be integers\n",
    "        assert np.equal(np.mod(label, 1), 0).all()\n",
    "        ## labels need to be in range pipe_name[2]\n",
    "        assert (label >= label_range[0]).all()\n",
    "        assert (label <= label_range[1]).all()\n",
    "\n",
    "print(\"OK\")"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
